# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import os
import random
import string
import subprocess
import tempfile

from ..base_constraint_satisfaction_generator import (
    BaseConstraintSatisfactionTestGenerator,
)


class PythonGenerator(BaseConstraintSatisfactionTestGenerator):
    """Constraint-satisfaction test generator that emits a Python checker.

    Each test case is a standalone Python script. The script scores a
    candidate string, passed as argv[1], against a fixed set of string
    properties. It prints ``score / total`` as a float.
    """

    # Upper bound on how long a single checker run may take before it is
    # treated as a failure. The generated script does a trivial amount of
    # work, so any run approaching this limit is broken.
    RUN_TIMEOUT_SECONDS: float = 10.0

    def generate_prompt(self, code: str) -> str:
        """Build the prompt that asks for a string scoring 1.0 on ``code``.

        Args:
            code: Checker source as returned by ``generate_test_case``.

        Returns:
            The prompt text. It asks for a JSON response keyed by ``answer``.
        """
        prompt = f"""Here's a python code that takes a string and prints out a score between 0 and 1.0. Can you find the string that would output a score of 1.0? Respond in json format with "answer" as the key. Only return the input string without explanation. \n\ncode: {code}"""

        return prompt

    def generate_test_case(self) -> tuple[str, bytes, str]:
        """Create a random 20-character seed and a checker script built from it.

        ``calculate_constraints`` comes from the base class and is not visible
        here. It is assumed to populate the ``self.*`` attributes interpolated
        into the template below, e.g. ``self.length``, ``self.prime`` and
        ``self.total_tests``. NOTE(review): if it derives them from the seed
        exactly as ``analyze_string`` does, the seed itself scores 1.0.
        Confirm against the base class.

        Returns:
            A tuple ``("python", seed_bytes, checker_source)``.
        """
        random_seed = "".join(
            random.choices(string.ascii_letters + string.digits, k=20)
        )
        self.calculate_constraints(random_seed)

        # NOTE: the template below is runtime output and must stay in sync
        # with the base class's constraint computation. `\\d` renders as
        # `\d` inside the generated raw regex. `\"` renders as `"`.
        python_code = f"""
import string
import sys
import re
import math

def analyze_string(s):
    def is_prime(n):
        if n <= 1 or (n % 2 == 0 and n > 2):
            return False
        return all(n % i for i in range(3, int(math.sqrt(n)) + 1, 2))

    vowels = "aeiouAEIOU"
    consonants = "bcdfghjklmnpqrstvwxyzBCDFGHJKLMNPQRSTVWXYZ"
    vowel_count = sum(1 for char in s if char.lower() in vowels)
    consonant_count = sum(1 for char in s if char.lower() in consonants)
    digit_count = sum(1 for char in s if char.isdigit())
    unique_chars = len(set(s))
    uppercase_count = sum(1 for char in s if char.isupper())
    lowercase_count = sum(1 for char in s if char.islower())
    case_ratio = uppercase_count / lowercase_count if lowercase_count != 0 else -1
    # Additional analyses
    is_palindrome = s == s[::-1]
    numeric_sequence = bool(re.search(r'\\d+', s))
    special_count = sum(1 for char in s if not char.isalnum())
    has_special = any(char in string.punctuation for char in s)
    has_digit = any(char.isdigit() for char in s)
    has_alphabet = any(char.isalpha() for char in s)
    has_whitespace = any(char.isspace() for char in s)
    has_vowel = any(char in vowels for char in s)
    has_consonant = any(char in consonants for char in s)
    has_uppercase = any(char.isupper() for char in s)
    has_lowercase = any(char.islower() for char in s)
    length = len(s)
    length_is_prime = is_prime(len(s))

    score = 0
    score += length == {self.length}
    score += length_is_prime == {self.prime}
    score += has_uppercase == {self.has_uppercase}
    score += has_lowercase == {self.has_lowercase}
    score += has_digit == {self.has_digit}
    score += has_special == {self.has_special}
    score += has_vowel == {self.has_vowel}
    score += has_consonant == {self.has_consonant}
    score += has_whitespace == {self.has_whitespace}
    score += unique_chars == {self.unique_chars}
    score += is_palindrome == {self.is_palindrome}
    score += has_alphabet == {self.has_alphabet}
    score += uppercase_count == {self.uppercase_count}
    score += lowercase_count == {self.lowercase_count}
    score += case_ratio == {self.case_ratio}
    score += numeric_sequence == {self.numeric_sequence}
    score += vowel_count == {self.vowel_count}
    score += consonant_count == {self.consonant_count}
    score += digit_count == {self.digit_count}
    score += special_count == {self.special_count}

    total_case = {self.total_tests}
    return score/total_case

if __name__ == \"__main__\":
    if len(sys.argv) != 2:
        print("Usage: python script.py <candidate_string>")
        sys.exit(1)
    candidate_string = sys.argv[1]
    print(analyze_string(candidate_string))
"""

        return ("python", random_seed.encode(), python_code)

    def compile_and_run(self, code: str, input: bytes) -> str:
        """Run the checker ``code`` on the candidate ``input`` and return stdout.

        The script is written into a fresh private temporary directory, so
        concurrent evaluations never overwrite each other's script. The old
        fixed path ``/tmp/test_case.py`` was shared by every caller. The
        directory is removed when the run finishes.

        Args:
            code: Checker source produced by ``generate_test_case``.
            input: Candidate string as UTF-8 bytes. It is passed to the script
                as its single command-line argument. No shell is involved, so
                it is not interpreted.

        Returns:
            The script's stdout (the score followed by a newline). Returns
            ``"0.0"`` when the candidate cannot be decoded or contains a NUL
            byte, when the interpreter cannot be started, when the run exceeds
            ``RUN_TIMEOUT_SECONDS``, or when the script exits non-zero.
        """
        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                script_path = os.path.join(tmp_dir, "test_case.py")
                with open(script_path, "w", encoding="utf-8") as f:
                    f.write(code)
                run_result = subprocess.run(
                    ["python3", script_path, input.decode()],
                    capture_output=True,
                    text=True,
                    timeout=self.RUN_TIMEOUT_SECONDS,
                )
        except Exception:
            # Best-effort scoring: any failure to run the checker counts as 0.
            return "0.0"
        if run_result.returncode != 0:
            # A crashed or usage-error run has no score on stdout; returning
            # it would make score_from_output raise on an empty string.
            return "0.0"
        return run_result.stdout

    def score_from_output(self, output: str) -> float:
        """Parse the checker's printed score into a float.

        Surrounding whitespace, such as the trailing newline from ``print``,
        is accepted by ``float``.
        """
        return float(output)
